/*
 * Minio Cloud Storage, (C) 2017 Minio, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package main

import (
	"bufio"
	"encoding/hex"
	"flag"
	"fmt"
	"log"
	"os"
	"os/exec"
	"regexp"
	"strings"
)

// Command-line flags. The -a, -s, -c and -f switches enable optional
// post-processing passes that main applies to the generated Go assembly file.

// assembleFlag (-a) makes main run asm2plan9s on the generated file.
var assembleFlag = flag.Bool("a", false, "Immediately invoke asm2plan9s")

// stripFlag (-s) makes main remove trailing comments from opcode lines.
var stripFlag = flag.Bool("s", false, "Strip comments")

// compactFlag (-c) makes main repack LONG/WORD/BYTE runs into denser directives.
var compactFlag = flag.Bool("c", false, "Compact byte codes")

// formatFlag (-f) makes main run "asmfmt -w" on the generated file.
var formatFlag = flag.Bool("f", false, "Format using asmfmt")

// targetFlag (-t) names the target machine of the input code. It is not read
// in this file; presumably it is consumed elsewhere in the package.
var targetFlag = flag.String("t", "x86", "Target machine of input code")

// readLines reads the whole file at path into memory and returns its lines.
// Line terminators ("\n", and a "\r" immediately before it) are removed, as
// with bufio.Scanner's default line splitting.
//
// Lines of up to maxReadLineLength bytes are accepted; bufio.Scanner's default
// limit (64 KiB) would otherwise make a single long line fail the whole read
// with bufio.ErrTooLong. On any error the returned slice is nil.
func readLines(path string) ([]string, error) {
	const maxReadLineLength = 16 * 1024 * 1024

	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	// Read-only file: a Close error cannot lose data, so it is not checked.
	defer file.Close()

	scanner := bufio.NewScanner(file)
	// Start with a 64 KiB buffer but let it grow for unusually long lines.
	scanner.Buffer(make([]byte, 0, 64*1024), maxReadLineLength)

	var lines []string
	for scanner.Scan() {
		lines = append(lines, scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return lines, nil
}

// writeLines creates (or truncates) the file at path and writes each element
// of lines to it followed by a newline. When header is true, a build
// constraint line, an auto-generated notice and an empty line are written
// first.
//
// The error from closing the file is returned too. For a file that was
// written, a failed Close can mean the data never fully reached the disk.
func writeLines(lines []string, path string, header bool) (err error) {
	file, err := os.Create(path)
	if err != nil {
		return err
	}
	defer func() {
		// Report a Close failure unless an earlier error is already returned.
		if cerr := file.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	// bufio.Writer remembers the first write error and Flush returns it, so
	// the individual Fprintln results below need not be checked.
	w := bufio.NewWriter(file)
	if header {
		// NOTE(review): space-separated terms in a +build line are ORed, so
		// this constraint holds unless BOTH noasm and appengine are set;
		// "!noasm,!appengine" would AND them. Left unchanged because the
		// companion .go file must use a matching constraint — confirm intent.
		fmt.Fprintln(w, "//+build !noasm !appengine")
		fmt.Fprintln(w, "// AUTO-GENERATED BY C2GOASM -- DO NOT EDIT")
		fmt.Fprintln(w, "")
	}
	for _, line := range lines {
		fmt.Fprintln(w, line)
	}
	return w.Flush()
}

// process translates the clang assembly listing in assembly into Go assembly
// and returns the output lines, which carry no trailing newline characters.
//
// goCompanionFile is the path of the companion .go file. parseCompanionFile
// consults it to obtain each subroutine's Go arguments and return values.
// Subroutines are emitted in the order segmentSource returns them. Each one
// is preceded by its constants table (plus an empty line) when
// getCorrespondingTable finds one, and consecutive subroutines are separated
// by two empty lines.
//
// A subroutine that cannot be translated yields a non-nil error, instead of
// a panic, and the partial result is discarded.
func process(assembly []string, goCompanionFile string) ([]string, error) {

	// Split out the assembly source into subroutines and constant tables
	subroutines := segmentSource(assembly)
	tables := segmentConstTables(assembly)

	var result []string

	// Iterate over all subroutines
	for isubroutine, sub := range subroutines {

		golangArgs, golangReturns := parseCompanionFile(goCompanionFile, sub.name)
		stackArgs := argumentsOnStack(sub.body)
		// Arguments beyond the sixth are expected to be read from the stack.
		// NOTE(review): presumably the first six are passed in registers;
		// confirm against the target ABI.
		if len(golangArgs) > 6 && len(golangArgs)-6 < stackArgs.Number {
			return nil, fmt.Errorf("%v: found too few arguments on stack (%d) but needed %d",
				sub.name, len(golangArgs)-6, stackArgs.Number)
		}

		// Check for constants table
		if table := getCorrespondingTable(sub.body, tables); table.isPresent() {

			// Output constants table, followed by an empty line
			result = append(result, strings.Split(table.Constants, "\n")...)
			result = append(result, "")

			sub.table = table
		}

		// Create object to get offsets for stack pointer
		stack := NewStack(sub.epilogue, len(golangArgs), scanBodyForCalls(sub))

		// Write header for subroutine in go assembly
		result = append(result, writeGoasmPrologue(sub, stack, golangArgs, golangReturns)...)

		// Write body of code ("body" avoids shadowing the assembly parameter)
		body, err := writeGoasmBody(sub, stack, stackArgs, golangArgs, golangReturns)
		if err != nil {
			return nil, fmt.Errorf("%v: writeGoasmBody: %v", sub.name, err)
		}
		result = append(result, body...)

		if isubroutine < len(subroutines)-1 {
			// Two empty lines before the next subroutine. Elements are whole
			// lines that the writer terminates with a newline, so an empty
			// line is "" (as after the constants table); "\n" would double it.
			result = append(result, "", "")
		}
	}

	return result, nil
}

// stripGoasmComments rewrites file in place, removing the trailing "//"
// comment from every line that contains a LONG, WORD or BYTE directive.
// The code before the comment keeps its leading indentation, and trailing
// whitespace is dropped. A line whose only content is a comment is left
// untouched, and all other lines are copied unchanged. Read and write errors
// are fatal.
func stripGoasmComments(file string) {

	lines, err := readLines(file)
	if err != nil {
		log.Fatalf("readLines: %s", err)
	}

	for i, l := range lines {
		if !strings.Contains(l, "LONG") && !strings.Contains(l, "WORD") && !strings.Contains(l, "BYTE") {
			continue
		}

		code := l
		if idx := strings.Index(l, "//"); idx >= 0 {
			code = l[:idx]
		}

		opcode := strings.TrimSpace(code)
		if opcode == "" {
			// Comment-only line that merely mentions a directive name. There
			// is no code to keep, and splitting on an empty opcode would cut
			// the line down to its first character (e.g. "/").
			continue
		}

		// opcode starts at the first non-space character of code, so its
		// first occurrence marks where the leading indentation ends.
		lines[i] = code[:strings.Index(code, opcode)+len(opcode)]
	}

	err = writeLines(lines, file, false)
	if err != nil {
		log.Fatalf("writeLines: %s", err)
	}
}

// reverseBytes reverses the order of the two-character groups in s, turning a
// hex string such as "0a0b0c" into "0c0b0a". Groups are taken from the end
// of s, so when len(s) is odd the leading character is dropped.
func reverseBytes(s string) string {
	out := make([]byte, 0, len(s))
	for end := len(s); end >= 2; end -= 2 {
		out = append(out, s[end-2:end]...)
	}
	return string(out)
}

// compactArray re-encodes a run of opcode bytes as Go assembler data
// directives, widest first:
//   - as many lines as possible holding two QUADs (16 bytes per line),
//   - then at most one line with a single QUAD (8 bytes),
//   - then, for the final 1-7 bytes if any, one line with LONG, WORD and/or
//     BYTE, each used at most once, joined by "; ".
//
// Each operand is the matching slice of the hex encoding with its byte order
// reversed (reverseBytes), so the first opcode byte becomes the
// least-significant byte of the immediate. Every returned line is indented by
// four spaces. An empty input yields a nil slice.
func compactArray(opcodes []byte) []string {
	encoded := make([]byte, hex.EncodedLen(len(opcodes)))
	hex.Encode(encoded, opcodes)
	digits := string(encoded)

	var lines []string
	pos := 0
	remaining := func() int { return len(digits) - pos }

	for remaining() >= 32 {
		lines = append(lines, fmt.Sprintf("    QUAD $0x%s; QUAD $0x%s",
			reverseBytes(digits[pos:pos+16]), reverseBytes(digits[pos+16:pos+32])))
		pos += 32
	}
	for remaining() >= 16 {
		lines = append(lines, fmt.Sprintf("    QUAD $0x%s", reverseBytes(digits[pos:pos+16])))
		pos += 16
	}
	if remaining() == 0 {
		return lines
	}

	// Fewer than 8 bytes are left; widths below are counted in hex digits.
	var parts []string
	for _, d := range []struct {
		directive string
		width     int
	}{{"LONG", 8}, {"WORD", 4}, {"BYTE", 2}} {
		if remaining() >= d.width {
			parts = append(parts, fmt.Sprintf("%s $0x%s", d.directive, reverseBytes(digits[pos:pos+d.width])))
			pos += d.width
		}
	}
	return append(lines, "    "+strings.Join(parts, "; "))
}

// compactOpcodes rewrites file in place, replacing each run of consecutive
// lines that contain LONG, WORD or BYTE with the denser directives produced
// by compactArray.
//
// The "$0x..." hex immediates on those lines are decoded, and each
// immediate's bytes are appended in reverse order (lowest-order byte first).
// Anything else on an opcode line, such as a trailing comment, is discarded.
// All other lines are copied unchanged. Read, decode and write errors are
// fatal.
func compactOpcodes(file string) {

	lines, err := readLines(file)
	if err != nil {
		log.Fatalf("readLines: %s", err)
	}

	var result []string

	// Bytes of the current run of opcode lines, in emission order.
	opcodes := make([]byte, 0, 1000)

	// flush appends the compacted form of any pending opcode bytes to result
	// and empties the buffer for the next run.
	flush := func() {
		if len(opcodes) != 0 {
			result = append(result, compactArray(opcodes)...)
			opcodes = opcodes[:0]
		}
	}

	hexMatch := regexp.MustCompile(`(\$0x[0-9a-f]+)`)

	for _, l := range lines {
		if !strings.Contains(l, "LONG") && !strings.Contains(l, "WORD") && !strings.Contains(l, "BYTE") {
			flush()
			result = append(result, l)
			continue
		}
		for _, m := range hexMatch.FindAllString(l, -1) {
			digits := m[len("$0x"):]
			value := make([]byte, hex.DecodedLen(len(digits)))
			if _, err := hex.Decode(value, []byte(digits)); err != nil {
				log.Fatal(err)
			}
			for i := len(value) - 1; i >= 0; i-- { // append starting with lowest byte first
				opcodes = append(opcodes, value[i])
			}
		}
	}

	// A run of opcode lines may extend to the end of the file; without this
	// final flush those bytes would be silently lost.
	flush()

	err = writeLines(result, file, false)
	if err != nil {
		log.Fatalf("writeLines: %s", err)
	}
}

// main is the c2goasm entry point.
//
// Usage: c2goasm [flags] <clang-assembly.s> <output.s>
//
// It translates the first file into Go assembly and writes it to the second.
// The second file's name must end in ".s", and a companion file with the same
// name but a ".go" extension must exist. The optional flags then
// post-process the written file, in this order: -a (asm2plan9s), -s (strip
// comments), -c (compact byte codes) and -f (asmfmt).
//
// Errors go to stderr. The exit status is 2 for bad arguments and 1 for
// other failures (log.Fatalf also exits with 1).
func main() {

	flag.Parse()

	if flag.NArg() < 2 {
		fmt.Fprintf(os.Stderr, "error: not enough input files specified\n\n")
		fmt.Fprintln(os.Stderr, "usage: c2goasm /path/to/c-project/build/SomeGreatCode.cpp.s SomeGreatCode_amd64.s")
		os.Exit(2)
	}
	assemblyFile := flag.Arg(1)
	if !strings.HasSuffix(assemblyFile, ".s") {
		fmt.Fprintf(os.Stderr, "error: second parameter must have '.s' extension\n")
		os.Exit(2)
	}

	goCompanion := strings.TrimSuffix(assemblyFile, ".s") + ".go"
	if _, err := os.Stat(goCompanion); err != nil {
		if os.IsNotExist(err) {
			fmt.Fprintf(os.Stderr, "error: companion '.go' file is missing for %s\n", assemblyFile)
		} else {
			fmt.Fprintf(os.Stderr, "error: %v\n", err)
		}
		os.Exit(1)
	}

	fmt.Println("Processing", flag.Arg(0))
	lines, err := readLines(flag.Arg(0))
	if err != nil {
		log.Fatalf("readLines: %s", err)
	}

	result, err := process(lines, goCompanion)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	if err := writeLines(result, assemblyFile, true); err != nil {
		log.Fatalf("writeLines: %s", err)
	}

	if *assembleFlag {
		fmt.Println("Invoking asm2plan9s on", assemblyFile)
		// Include the tool's own output so the cause of a failure is visible.
		if out, err := exec.Command("asm2plan9s", assemblyFile).CombinedOutput(); err != nil {
			log.Fatalf("asm2plan9s: %v\n%s", err, out)
		}
	}

	if *stripFlag {
		stripGoasmComments(assemblyFile)
	}

	if *compactFlag {
		compactOpcodes(assemblyFile)
	}

	if *formatFlag {
		if out, err := exec.Command("asmfmt", "-w", assemblyFile).CombinedOutput(); err != nil {
			log.Fatalf("asmfmt: %v\n%s", err, out)
		}
	}
}
